{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._base_recurrent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BaseRecurrent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> The `BaseRecurrent` class contains standard methods shared across recurrent neural networks; these models possess the ability to process variable-length sequences of inputs through their internal memory states. The class is represented by `LSTM`, `GRU`, and `RNN`, along with other more sophisticated architectures like `MQCNN`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standard methods include `TemporalNorm` preprocessing, optimization utilities like parameter initialization, `training_step`, `validation_step`, and shared `fit` and `predict` methods.These shared methods enable all the `neuralforecast.models` compatibility with the `core.NeuralForecast` wrapper class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "\n",
    "from neuralforecast.common._scalers import TemporalNorm\n",
    "from neuralforecast.tsdataset import TimeSeriesDataModule\n",
    "from neuralforecast.utils import get_indexer_raise_missing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BaseRecurrent(pl.LightningModule):\n",
    "    \"\"\" Base Recurrent\n",
    "    \n",
    "    Base class for all recurrent-based models. The forecasts are produced sequentially between \n",
    "    windows.\n",
    "    \n",
    "    This class implements the basic functionality for all windows-based models, including:\n",
    "    - PyTorch Lightning's methods training_step, validation_step, predict_step. <br>\n",
    "    - fit and predict methods used by NeuralForecast.core class. <br>\n",
    "    - sampling and wrangling methods to sequential windows. <br>\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size,\n",
    "                 inference_input_size,\n",
    "                 loss,\n",
    "                 valid_loss,\n",
    "                 learning_rate,\n",
    "                 max_steps,\n",
    "                 val_check_steps,\n",
    "                 batch_size,\n",
    "                 valid_batch_size,\n",
    "                 scaler_type='robust',\n",
    "                 num_lr_decays=0,\n",
    "                 early_stop_patience_steps=-1,\n",
    "                 futr_exog_list=None,\n",
    "                 hist_exog_list=None,\n",
    "                 stat_exog_list=None,\n",
    "                 num_workers_loader=0,\n",
    "                 drop_last_loader=False,\n",
    "                 random_seed=1, \n",
    "                 alias=None,\n",
    "                 **trainer_kwargs):\n",
    "        super(BaseRecurrent, self).__init__()\n",
    "\n",
    "        self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
    "        self.random_seed = random_seed\n",
    "        pl.seed_everything(self.random_seed, workers=True)\n",
    "\n",
    "        # Padder to complete train windows, \n",
    "        # example y=[1,2,3,4,5] h=3 -> last y_output = [5,0,0]\n",
    "        self.h = h\n",
    "        self.input_size = input_size\n",
    "        self.inference_input_size = inference_input_size\n",
    "        self.padder = nn.ConstantPad1d(padding=(0, self.h), value=0)\n",
    "\n",
    "        # Loss\n",
    "        self.loss = loss\n",
    "        if valid_loss is None:\n",
    "            self.valid_loss = loss\n",
    "        else:\n",
    "            self.valid_loss = valid_loss\n",
    "        self.train_trajectories = []\n",
    "        self.valid_trajectories = []\n",
    "\n",
    "        if str(type(self.loss)) == \"<class 'neuralforecast.losses.pytorch.DistributionLoss'>\" and\\\n",
    "            self.loss.distribution=='Bernoulli':\n",
    "                raise Exception('Temporal Classification not yet available for Recurrent-based models')\n",
    "\n",
    "        # Valid batch_size\n",
    "        self.batch_size = batch_size\n",
    "        if valid_batch_size is None:\n",
    "            self.valid_batch_size = batch_size\n",
    "        else:\n",
    "            self.valid_batch_size = valid_batch_size\n",
    "\n",
    "        # Optimization\n",
    "        self.learning_rate = learning_rate\n",
    "        self.max_steps = max_steps\n",
    "        self.num_lr_decays = num_lr_decays\n",
    "        self.lr_decay_steps = max(max_steps // self.num_lr_decays, 1) if self.num_lr_decays > 0 else 10e7\n",
    "        self.early_stop_patience_steps = early_stop_patience_steps\n",
    "        self.val_check_steps = val_check_steps\n",
    "\n",
    "        # Variables\n",
    "        self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
    "        self.hist_exog_list = list(hist_exog_list) if hist_exog_list is not None else []\n",
    "        self.stat_exog_list = list(stat_exog_list) if stat_exog_list is not None else []\n",
    "\n",
    "        # Scaler\n",
    "        self.scaler = TemporalNorm(scaler_type=scaler_type, dim=-1,  # Time dimension is -1.\n",
    "                        num_features=1+len(self.hist_exog_list)+len(self.futr_exog_list))        \n",
    "\n",
    "        # Fit arguments\n",
    "        self.val_size = 0\n",
    "        self.test_size = 0\n",
    "\n",
    "        ## Trainer arguments ##\n",
    "        # Max steps, validation steps and check_val_every_n_epoch\n",
    "        trainer_kwargs = {**trainer_kwargs,\n",
    "                          **{'max_steps': max_steps}}\n",
    "\n",
    "        if 'max_epochs' in trainer_kwargs.keys():\n",
    "            raise Exception('max_epochs is deprecated, use max_steps instead.')\n",
    "\n",
    "        # Callbacks\n",
    "        if 'callbacks' not in trainer_kwargs and self.early_stop_patience_steps > 0:\n",
    "            trainer_kwargs['callbacks'] = [\n",
    "                EarlyStopping(\n",
    "                    monitor='ptl/val_loss', patience=self.early_stop_patience_steps\n",
    "                )\n",
    "            ]\n",
    "\n",
    "        # Add GPU accelerator if available\n",
    "        if trainer_kwargs.get('accelerator', None) is None:\n",
    "            if torch.cuda.is_available():\n",
    "                trainer_kwargs['accelerator'] = \"gpu\"\n",
    "        if trainer_kwargs.get('devices', None) is None:\n",
    "            if torch.cuda.is_available():\n",
    "                trainer_kwargs['devices'] = -1\n",
    "\n",
    "        # Avoid saturating local memory, disabled fit model checkpoints\n",
    "        if trainer_kwargs.get('enable_checkpointing', None) is None:\n",
    "            trainer_kwargs['enable_checkpointing'] = False\n",
    "\n",
    "        self.trainer_kwargs = trainer_kwargs\n",
    "\n",
    "        # DataModule arguments\n",
    "        self.num_workers_loader = num_workers_loader\n",
    "        self.drop_last_loader = drop_last_loader\n",
    "        # used by on_validation_epoch_end hook\n",
    "        self.validation_step_outputs = []\n",
    "        self.alias = alias\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return type(self).__name__ if self.alias is None else self.alias\n",
    "\n",
    "    def on_fit_start(self):\n",
    "        torch.manual_seed(self.random_seed)\n",
    "        np.random.seed(self.random_seed)\n",
    "        random.seed(self.random_seed)\n",
    "        \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
    "        scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
    "                                                                  step_size=self.lr_decay_steps,\n",
    "                                                                  gamma=0.5),\n",
    "                     'frequency': 1,\n",
    "                     'interval': 'step'}\n",
    "        return {'optimizer': optimizer, 'lr_scheduler': scheduler}\n",
    "\n",
    "    def _get_temporal_exogenous_cols(self, temporal_cols):\n",
    "        temporal_data_cols = list(set(temporal_cols.tolist()) &\\\n",
    "                                  set(self.hist_exog_list + self.futr_exog_list))\n",
    "        return temporal_data_cols\n",
    "\n",
    "    def _normalization(self, batch, val_size=0, test_size=0):\n",
    "        temporal = batch['temporal'] # B, C, T\n",
    "        temporal_cols = batch['temporal_cols'].copy()\n",
    "        y_idx = batch['y_idx']\n",
    "\n",
    "        # Separate data and mask\n",
    "        temporal_data_cols = self._get_temporal_exogenous_cols(temporal_cols=temporal_cols)\n",
    "        temporal_idxs = get_indexer_raise_missing(temporal_cols, temporal_data_cols)\n",
    "        temporal_idxs = np.append(y_idx, temporal_idxs)\n",
    "        temporal_data = temporal[:, temporal_idxs, :]\n",
    "        temporal_mask = temporal[:, temporal_cols.get_loc('available_mask'), :].clone()\n",
    "\n",
    "        # Remove validation and test set to prevent leakeage\n",
    "        if val_size + test_size > 0:\n",
    "            cutoff = val_size + test_size\n",
    "            temporal_mask[:, -cutoff:] = 0\n",
    "\n",
    "        # Normalize. self.scaler stores the shift and scale for inverse transform\n",
    "        temporal_mask = temporal_mask.unsqueeze(1) # Add channel dimension for scaler.transform.\n",
    "        temporal_data = self.scaler.transform(x=temporal_data, mask=temporal_mask)\n",
    "\n",
    "        # Replace values in windows dict\n",
    "        temporal[:, temporal_idxs, :] = temporal_data\n",
    "        batch['temporal'] = temporal\n",
    "\n",
    "        return batch\n",
    "\n",
    "    def _inv_normalization(self, y_hat, temporal_cols, y_idx):\n",
    "        # Receives window predictions [B, seq_len, H, output]\n",
    "        # Broadcasts outputs and inverts normalization\n",
    "\n",
    "        # Get 'y' scale and shift, and add W dimension\n",
    "        y_loc = self.scaler.x_shift[:, [y_idx], 0].flatten() #[B,C,T] -> [B]        \n",
    "        y_scale = self.scaler.x_scale[:, [y_idx], 0].flatten() #[B,C,T] -> [B]\n",
    "\n",
    "        # Expand scale and shift to y_hat dimensions\n",
    "        y_loc = y_loc.view(*y_loc.shape, *(1,)*(y_hat.ndim-1))#.expand(y_hat)        \n",
    "        y_scale = y_scale.view(*y_scale.shape, *(1,)*(y_hat.ndim-1))#.expand(y_hat)\n",
    "\n",
    "        y_hat = self.scaler.inverse_transform(z=y_hat, x_scale=y_scale, x_shift=y_loc)\n",
    "\n",
    "        return y_hat, y_loc, y_scale\n",
    "\n",
    "    def _create_windows(self, batch, step):\n",
    "        temporal = batch['temporal']\n",
    "        temporal_cols = batch['temporal_cols']\n",
    "\n",
    "        if step == 'train':\n",
    "            if self.val_size + self.test_size > 0:\n",
    "                cutoff = -self.val_size - self.test_size\n",
    "                temporal = temporal[:, :, :cutoff]\n",
    "            temporal = self.padder(temporal)\n",
    "\n",
    "            # Truncate batch to shorter time-series \n",
    "            av_condition = torch.nonzero(torch.min(temporal[:, temporal_cols.get_loc('available_mask')], axis=0).values)\n",
    "            min_time_stamp = int(av_condition.min())\n",
    "            \n",
    "            available_ts = temporal.shape[-1] - min_time_stamp\n",
    "            if available_ts < 1 + self.h:\n",
    "                raise Exception(\n",
    "                    'Time series too short for given input and output size. \\n'\n",
    "                    f'Available timestamps: {available_ts}'\n",
    "                )\n",
    "\n",
    "            temporal = temporal[:, :, min_time_stamp:]\n",
    "\n",
    "        if step == 'val':\n",
    "            if self.test_size > 0:\n",
    "                temporal = temporal[:, :, :-self.test_size]\n",
    "            temporal = self.padder(temporal)\n",
    "\n",
    "        if step == 'predict':\n",
    "            if (self.test_size == 0) and (len(self.futr_exog_list)==0):\n",
    "                temporal = self.padder(temporal)\n",
    "\n",
    "            # Test size covers all data, pad left one timestep with zeros\n",
    "            if temporal.shape[-1] == self.test_size:\n",
    "                padder_left = nn.ConstantPad1d(padding=(1, 0), value=0)\n",
    "                temporal = padder_left(temporal)\n",
    "\n",
    "        # Parse batch\n",
    "        window_size = 1 + self.h # 1 for current t and h for future\n",
    "        windows = temporal.unfold(dimension=-1,\n",
    "                                  size=window_size,\n",
    "                                  step=1)\n",
    "\n",
    "        # Truncated backprogatation/inference (shorten sequence where RNNs unroll)\n",
    "        n_windows = windows.shape[2]\n",
    "        input_size = -1\n",
    "        if (step == 'train') and (self.input_size>0):\n",
    "            input_size = self.input_size\n",
    "            if (input_size > 0) and (n_windows > input_size):\n",
    "                max_sampleable_time = n_windows-self.input_size+1\n",
    "                start = np.random.choice(max_sampleable_time)\n",
    "                windows = windows[:, :, start:(start+input_size), :]\n",
    "\n",
    "        if (step == 'val') and (self.inference_input_size>0):\n",
    "            cutoff = self.inference_input_size + self.val_size\n",
    "            windows = windows[:, :, -cutoff:, :]\n",
    "\n",
    "        if (step == 'predict') and (self.inference_input_size>0):\n",
    "            cutoff = self.inference_input_size + self.test_size\n",
    "            windows = windows[:, :, -cutoff:, :]\n",
    "        \n",
    "        # [B, C, input_size, 1+H]\n",
    "        windows_batch = dict(temporal=windows,\n",
    "                             temporal_cols=temporal_cols,\n",
    "                             static=batch.get('static', None),\n",
    "                             static_cols=batch.get('static_cols', None))\n",
    "\n",
    "        return windows_batch\n",
    "\n",
    "    def _parse_windows(self, batch, windows):\n",
    "        # [B, C, seq_len, 1+H]\n",
    "        # Filter insample lags from outsample horizon\n",
    "        mask_idx = batch['temporal_cols'].get_loc('available_mask')\n",
    "        y_idx = batch['y_idx']        \n",
    "        insample_y = windows['temporal'][:, y_idx, :, :-self.h]\n",
    "        insample_mask = windows['temporal'][:, mask_idx, :, :-self.h]\n",
    "        outsample_y = windows['temporal'][:, y_idx, :, -self.h:].contiguous()\n",
    "        outsample_mask = windows['temporal'][:, mask_idx, :, -self.h:].contiguous()\n",
    "\n",
    "        # Filter historic exogenous variables\n",
    "        if len(self.hist_exog_list):\n",
    "            hist_exog_idx = get_indexer_raise_missing(windows['temporal_cols'], self.hist_exog_list)\n",
    "            hist_exog = windows['temporal'][:, hist_exog_idx, :, :-self.h]\n",
    "        else:\n",
    "            hist_exog = None\n",
    "        \n",
    "        # Filter future exogenous variables\n",
    "        if len(self.futr_exog_list):\n",
    "            futr_exog_idx = get_indexer_raise_missing(windows['temporal_cols'], self.futr_exog_list)\n",
    "            futr_exog = windows['temporal'][:, futr_exog_idx, :, :]\n",
    "        else:\n",
    "            futr_exog = None\n",
    "        # Filter static variables\n",
    "        if len(self.stat_exog_list):\n",
    "            static_idx = get_indexer_raise_missing(windows['static_cols'], self.stat_exog_list)\n",
    "            stat_exog = windows['static'][:, static_idx]\n",
    "        else:\n",
    "            stat_exog = None\n",
    "\n",
    "        return insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
    "               hist_exog, futr_exog, stat_exog\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # Create and normalize windows [Ws, L+H, C]\n",
    "        batch = self._normalization(batch, val_size=self.val_size, test_size=self.test_size)\n",
    "        windows = self._create_windows(batch, step='train')\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
    "               hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [B, seq_len, 1]\n",
    "                             insample_mask=insample_mask, # [B, seq_len, 1]\n",
    "                             futr_exog=futr_exog, # [B, F, seq_len, 1+H]\n",
    "                             hist_exog=hist_exog, # [B, C, seq_len]\n",
    "                             stat_exog=stat_exog) # [B, S]\n",
    "\n",
    "        # Model predictions\n",
    "        output = self(windows_batch) # tuple([B, seq_len, H, output])\n",
    "        if self.loss.is_distribution_output:\n",
    "            outsample_y, y_loc, y_scale = self._inv_normalization(y_hat=outsample_y,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=batch['y_idx'])\n",
    "            B = output[0].size()[0]\n",
    "            T = output[0].size()[1]\n",
    "            H = output[0].size()[2]\n",
    "            output = [arg.view(-1, *(arg.size()[2:])) for arg in output]\n",
    "            outsample_y = outsample_y.view(B*T,H)\n",
    "            outsample_mask = outsample_mask.view(B*T,H)\n",
    "            y_loc = y_loc.repeat_interleave(repeats=T, dim=0).squeeze(-1)\n",
    "            y_scale = y_scale.repeat_interleave(repeats=T, dim=0).squeeze(-1)\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            loss = self.loss(y=outsample_y, distr_args=distr_args, mask=outsample_mask)\n",
    "        else:\n",
    "            loss = self.loss(y=outsample_y, y_hat=output, mask=outsample_mask)\n",
    "\n",
    "        if torch.isnan(loss):\n",
    "            print('Model Parameters', self.hparams)\n",
    "            print('insample_y', torch.isnan(insample_y).sum())\n",
    "            print('outsample_y', torch.isnan(outsample_y).sum())\n",
    "            print('output', torch.isnan(output).sum())\n",
    "            raise Exception('Loss is NaN, training stopped.')\n",
    "\n",
    "        self.log('train_loss', loss, batch_size=self.batch_size, prog_bar=True, on_epoch=True)\n",
    "        self.train_trajectories.append((self.global_step, float(loss)))\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        if self.val_size == 0:\n",
    "            return np.nan\n",
    "\n",
    "        # Create and normalize windows [Ws, L+H, C]\n",
    "        batch = self._normalization(batch, val_size=self.val_size, test_size=self.test_size)\n",
    "        windows = self._create_windows(batch, step='val')\n",
    "        y_idx = batch['y_idx']\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
    "               hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [B, seq_len, 1]\n",
    "                             insample_mask=insample_mask, # [B, seq_len, 1]\n",
    "                             futr_exog=futr_exog, # [B, F, seq_len, 1+H]\n",
    "                             hist_exog=hist_exog, # [B, C, seq_len]\n",
    "                             stat_exog=stat_exog) # [B, S]\n",
    "\n",
    "        # Remove train y_hat (+1 and -1 for padded last window with zeros)\n",
    "        # tuple([B, seq_len, H, output]) -> tuple([B, validation_size, H, output])\n",
    "        val_windows = (self.val_size) + 1\n",
    "        outsample_y = outsample_y[:, -val_windows:-1, :]\n",
    "        outsample_mask = outsample_mask[:, -val_windows:-1, :]        \n",
    "\n",
    "        # Model predictions\n",
    "        output = self(windows_batch) # tuple([B, seq_len, H, output])\n",
    "        if self.loss.is_distribution_output:\n",
    "            output = [arg[:, -val_windows:-1] for arg in output]\n",
    "            outsample_y, y_loc, y_scale = self._inv_normalization(y_hat=outsample_y,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "            B = output[0].size()[0]\n",
    "            T = output[0].size()[1]\n",
    "            H = output[0].size()[2]\n",
    "            output = [arg.reshape(-1, *(arg.size()[2:])) for arg in output]\n",
    "            outsample_y = outsample_y.reshape(B*T,H)\n",
    "            outsample_mask = outsample_mask.reshape(B*T,H)\n",
    "            y_loc = y_loc.repeat_interleave(repeats=T, dim=0).squeeze(-1)\n",
    "            y_scale = y_scale.repeat_interleave(repeats=T, dim=0).squeeze(-1)\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            _, sample_mean, quants  = self.loss.sample(distr_args=distr_args)\n",
    "\n",
    "            if str(type(self.valid_loss)) in\\\n",
    "                [\"<class 'neuralforecast.losses.pytorch.sCRPS'>\", \"<class 'neuralforecast.losses.pytorch.MQLoss'>\"]:\n",
    "                output = quants\n",
    "            elif str(type(self.valid_loss)) in [\"<class 'neuralforecast.losses.pytorch.relMSE'>\"]:\n",
    "                output = torch.unsqueeze(sample_mean, dim=-1) # [N,H,1] -> [N,H]\n",
    "            \n",
    "        else:\n",
    "            output = output[:, -val_windows:-1, :]\n",
    "\n",
    "        # Validation Loss evaluation\n",
    "        if self.valid_loss.is_distribution_output:\n",
    "            valid_loss = self.valid_loss(y=outsample_y, distr_args=distr_args, mask=outsample_mask)\n",
    "        else:\n",
    "            outsample_y, _, _ = self._inv_normalization(y_hat=outsample_y, temporal_cols=batch['temporal_cols'], y_idx=y_idx)\n",
    "            output, _, _      = self._inv_normalization(y_hat=output, temporal_cols=batch['temporal_cols'], y_idx=y_idx)\n",
    "            valid_loss = self.valid_loss(y=outsample_y, y_hat=output, mask=outsample_mask)\n",
    "\n",
    "        if torch.isnan(valid_loss):\n",
    "            raise Exception('Loss is NaN, training stopped.')\n",
    "\n",
    "        self.log('valid_loss', valid_loss, batch_size=self.batch_size, prog_bar=True, on_epoch=True)\n",
    "        self.validation_step_outputs.append(valid_loss)\n",
    "        return valid_loss\n",
    "\n",
    "    def on_validation_epoch_end(self):\n",
    "        if self.val_size == 0:\n",
    "            return\n",
    "        avg_loss = torch.stack(self.validation_step_outputs).mean()\n",
    "        self.log(\"ptl/val_loss\", avg_loss, batch_size=self.batch_size)\n",
    "        self.valid_trajectories.append((self.global_step, float(avg_loss)))\n",
    "        self.validation_step_outputs.clear() # free memory (compute `avg_loss` per epoch) \n",
    "\n",
    "    def predict_step(self, batch, batch_idx):\n",
    "        # Create and normalize windows [Ws, L+H, C]\n",
    "        batch = self._normalization(batch, val_size=0, test_size=self.test_size)\n",
    "        windows = self._create_windows(batch, step='predict')\n",
    "        y_idx = batch['y_idx']\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, _, _, \\\n",
    "               hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [B, seq_len, 1]\n",
    "                             insample_mask=insample_mask, # [B, seq_len, 1]\n",
    "                             futr_exog=futr_exog, # [B, F, seq_len, 1+H]\n",
    "                             hist_exog=hist_exog, # [B, C, seq_len]\n",
    "                             stat_exog=stat_exog) # [B, S]\n",
    "\n",
    "        # Model Predictions\n",
    "        output = self(windows_batch) # tuple([B, seq_len, H], ...)\n",
    "        if self.loss.is_distribution_output:\n",
    "            _, y_loc, y_scale = self._inv_normalization(y_hat=output[0],\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "            B = output[0].size()[0]\n",
    "            T = output[0].size()[1]\n",
    "            H = output[0].size()[2]\n",
    "            output = [arg.reshape(-1, *(arg.size()[2:])) for arg in output]\n",
    "            y_loc = y_loc.repeat_interleave(repeats=T, dim=0).squeeze(-1)\n",
    "            y_scale = y_scale.repeat_interleave(repeats=T, dim=0).squeeze(-1)\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            _, sample_mean, quants = self.loss.sample(distr_args=distr_args)\n",
    "            y_hat = torch.concat((sample_mean, quants), axis=2)\n",
    "            y_hat = y_hat.view(B, T, H, -1)\n",
    "\n",
    "            if self.loss.return_params:\n",
    "                distr_args = torch.stack(distr_args, dim=-1)\n",
    "                distr_args = torch.reshape(distr_args, (B, T, H, -1))\n",
    "                y_hat = torch.concat((y_hat, distr_args), axis=3)\n",
    "        else:\n",
    "            y_hat, _, _ = self._inv_normalization(y_hat=output,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "        return y_hat\n",
    "\n",
    "    def fit(self, dataset, val_size=0, test_size=0, random_seed=None):\n",
    "        \"\"\" Fit.\n",
    "\n",
    "        The `fit` method, optimizes the neural network's weights using the\n",
    "        initialization parameters (`learning_rate`, `batch_size`, ...)\n",
    "        and the `loss` function as defined during the initialization. \n",
    "        Within `fit` we use a PyTorch Lightning `Trainer` that\n",
    "        inherits the initialization's `self.trainer_kwargs`, to customize\n",
    "        its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n",
    "\n",
    "        The method is designed to be compatible with SKLearn-like classes\n",
    "        and in particular to be compatible with the StatsForecast library.\n",
    "\n",
    "        By default the `model` is not saving training checkpoints to protect \n",
    "        disk memory, to get them change `enable_checkpointing=True` in `__init__`.        \n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
    "        `val_size`: int, validation size for temporal cross-validation.<br>\n",
    "        `test_size`: int, test size for temporal cross-validation.<br>\n",
    "        `random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.<br>\n",
    "        \"\"\"\n",
    "\n",
    "        # Check exogenous variables are contained in dataset\n",
    "        temporal_cols = set(dataset.temporal_cols.tolist())\n",
    "        static_cols = set(dataset.static_cols.tolist() if dataset.static_cols is not None else [])\n",
    "        if len(set(self.hist_exog_list) - temporal_cols)>0:\n",
    "            raise Exception(f'{set(self.hist_exog_list) - temporal_cols} historical exogenous variables not found in input dataset')\n",
    "        if len(set(self.futr_exog_list) - temporal_cols)>0:\n",
    "            raise Exception(f'{set(self.futr_exog_list) - temporal_cols} future exogenous variables not found in input dataset')\n",
    "        if len(set(self.stat_exog_list) - static_cols)>0:\n",
    "            raise Exception(f'{set(self.stat_exog_list) - static_cols} static exogenous variables not found in input dataset')\n",
    "        \n",
    "        # Restart random seed\n",
    "        if random_seed is None:\n",
    "            random_seed = self.random_seed\n",
    "        torch.manual_seed(random_seed)\n",
    "\n",
    "        self.val_size = val_size\n",
    "        self.test_size = test_size\n",
    "        datamodule = TimeSeriesDataModule(\n",
    "            dataset=dataset, \n",
    "            batch_size=self.batch_size,\n",
    "            valid_batch_size=self.valid_batch_size,\n",
    "            num_workers=self.num_workers_loader,\n",
    "            drop_last=self.drop_last_loader\n",
    "        )\n",
    "\n",
    "        if self.val_check_steps > self.max_steps:\n",
    "            warnings.warn('val_check_steps is greater than max_steps, \\\n",
    "                    setting val_check_steps to max_steps')\n",
    "        val_check_interval = min(self.val_check_steps, self.max_steps)\n",
    "        self.trainer_kwargs['val_check_interval'] = int(val_check_interval)\n",
    "        self.trainer_kwargs['check_val_every_n_epoch'] = None\n",
    "\n",
    "        trainer = pl.Trainer(**self.trainer_kwargs)\n",
    "        trainer.fit(self, datamodule=datamodule)\n",
    "        return trainer\n",
    "\n",
    "    def predict(self, dataset, step_size=1,\n",
    "                random_seed=None, **data_module_kwargs):\n",
    "        \"\"\" Predict.\n",
    "\n",
    "        Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
    "        `step_size`: int=1, Step size between each window. Recurrent models only support `step_size=1`.<br>\n",
    "        `random_seed`: int=None, random_seed for PyTorch's global generator (`torch.manual_seed`), overwrites model.__init__'s.<br>\n",
    "        `**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `fcsts`: np.ndarray with one column per entry of `self.loss.output_names`.<br>\n",
    "\n",
    "        **Raises:**<br>\n",
    "        `Exception`: if `step_size > 1`, or if any historical, future or static exogenous variable of the model is missing from `dataset`.\n",
    "        \"\"\"\n",
    "        # Validate step_size first, before reseeding the global RNG or building a Trainer.\n",
    "        if step_size > 1:\n",
    "            raise Exception('Recurrent models do not support step_size > 1')\n",
    "\n",
    "        # Check exogenous variables are contained in dataset\n",
    "        temporal_cols = set(dataset.temporal_cols.tolist())\n",
    "        static_cols = set(dataset.static_cols.tolist() if dataset.static_cols is not None else [])\n",
    "        missing_hist = set(self.hist_exog_list) - temporal_cols\n",
    "        if missing_hist:\n",
    "            raise Exception(f'{missing_hist} historical exogenous variables not found in input dataset')\n",
    "        missing_futr = set(self.futr_exog_list) - temporal_cols\n",
    "        if missing_futr:\n",
    "            raise Exception(f'{missing_futr} future exogenous variables not found in input dataset')\n",
    "        missing_stat = set(self.stat_exog_list) - static_cols\n",
    "        if missing_stat:\n",
    "            raise Exception(f'{missing_stat} static exogenous variables not found in input dataset')\n",
    "\n",
    "        # Restart PyTorch's random seed; numpy / `random` generators are not reseeded here.\n",
    "        if random_seed is None:\n",
    "            random_seed = self.random_seed\n",
    "        torch.manual_seed(random_seed)\n",
    "\n",
    "        # Protect when case of multiple gpu. PL does not support return preds with multiple gpu,\n",
    "        # so prediction is restricted to the first device.\n",
    "        pred_trainer_kwargs = self.trainer_kwargs.copy()\n",
    "        if (pred_trainer_kwargs.get('accelerator', None) == \"gpu\") and (torch.cuda.device_count() > 1):\n",
    "            pred_trainer_kwargs['devices'] = [0]\n",
    "\n",
    "        trainer = pl.Trainer(**pred_trainer_kwargs)\n",
    "\n",
    "        datamodule = TimeSeriesDataModule(\n",
    "            dataset=dataset,\n",
    "            valid_batch_size=self.valid_batch_size,\n",
    "            num_workers=self.num_workers_loader,\n",
    "            **data_module_kwargs\n",
    "        )\n",
    "        fcsts = trainer.predict(self, datamodule=datamodule)\n",
    "\n",
    "        # Keep only the trailing forecast windows of each batch prediction ([N,T,H,output]).\n",
    "        # With test_size > 0 the leading warmup windows (from train and validation) are\n",
    "        # dropped; otherwise only the last window is kept. The last dim is not indexed,\n",
    "        # for univariate output compatibility.\n",
    "        if self.test_size > 0:\n",
    "            n_windows = 1 + self.test_size - self.h\n",
    "        else:\n",
    "            n_windows = 1\n",
    "        fcsts = torch.vstack([fcst[:, -n_windows:, :] for fcst in fcsts])\n",
    "        fcsts = fcsts.numpy().flatten().reshape(-1, len(self.loss.output_names))\n",
    "        return fcsts\n",
    "\n",
    "    def set_test_size(self, test_size):\n",
    "        self.test_size = test_size\n",
    "\n",
    "    def get_test_size(self):\n",
    "        return self.test_size\n",
    "\n",
    "    def save(self, path):\n",
    "        \"\"\" BaseRecurrent.save\n",
    "\n",
    "        Save the fitted model to disk.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `path`: str, path to save the model.<br>\n",
    "        \"\"\"\n",
    "        self.trainer.save_checkpoint(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(BaseRecurrent, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(BaseRecurrent.fit, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(BaseRecurrent.predict, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.utils import AirPassengersDF\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesDataModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# TODO(review): add h=0,1 unit test for _parse_windows\n",
    "# Declare batch.\n",
    "# Work on a copy so the shared AirPassengersDF is not mutated; `x` is a running index\n",
    "# and `x2` its double (np.array(len(...)) would yield a constant column instead).\n",
    "air_passengers_exog_df = AirPassengersDF.assign(\n",
    "    x=np.arange(len(AirPassengersDF)),\n",
    "    x2=np.arange(len(AirPassengersDF)) * 2,\n",
    ")\n",
    "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=air_passengers_exog_df)\n",
    "data = TimeSeriesDataModule(dataset=dataset, batch_size=1, drop_last=True)\n",
    "\n",
    "train_loader = data.train_dataloader()\n",
    "batch = next(iter(train_loader))\n",
    "\n",
    "# Test that hist_exog_list and futr_exog_list correctly filter data that is sent to scaler.\n",
    "baserecurrent = BaseRecurrent(h=12,\n",
    "                              input_size=117,\n",
    "                              hist_exog_list=['x', 'x2'],\n",
    "                              futr_exog_list=['x'],\n",
    "                              loss=MAE(),\n",
    "                              valid_loss=MAE(),\n",
    "                              learning_rate=0.001,\n",
    "                              max_steps=1,\n",
    "                              val_check_steps=0,\n",
    "                              batch_size=1,\n",
    "                              valid_batch_size=1,\n",
    "                              windows_batch_size=10,\n",
    "                              inference_input_size=2,\n",
    "                              start_padding_enabled=True)\n",
    "\n",
    "windows = baserecurrent._create_windows(batch, step='train')\n",
    "\n",
    "# Column names carried alongside the windowed temporal tensor.\n",
    "temporal_cols = windows['temporal_cols'].copy()\n",
    "temporal_data_cols = baserecurrent._get_temporal_exogenous_cols(temporal_cols=temporal_cols)\n",
    "\n",
    "test_eq(set(temporal_data_cols), set(['x', 'x2']))\n",
    "# NOTE(review): 117 and 12+1 appear to match input_size and h+1 -- confirm the window layout.\n",
    "test_eq(windows['temporal'].shape, torch.Size([1,len(['y', 'x', 'x2', 'available_mask']),117,12+1]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
